{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "0df75687",
   "metadata": {
    "cellId": "jziuwa87tkxdnpvgjqd3q",
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Requirement already satisfied: pycocotools in /home/jupyter/.local/lib/python3.8/site-packages (2.0.5)\n",
      "Requirement already satisfied: matplotlib>=2.1.0 in /kernel/lib/python3.8/site-packages (from pycocotools) (3.3.3)\n",
      "Requirement already satisfied: numpy in /kernel/fallback/lib/python3.8/site-packages (from pycocotools) (1.19.4)\n",
      "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /kernel/lib/python3.8/site-packages (from matplotlib>=2.1.0->pycocotools) (2.4.7)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /kernel/lib/python3.8/site-packages (from matplotlib>=2.1.0->pycocotools) (9.2.0)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /kernel/lib/python3.8/site-packages (from matplotlib>=2.1.0->pycocotools) (1.4.4)\n",
      "Requirement already satisfied: python-dateutil>=2.1 in /kernel/lib/python3.8/site-packages (from matplotlib>=2.1.0->pycocotools) (2.8.2)\n",
      "Requirement already satisfied: cycler>=0.10 in /kernel/lib/python3.8/site-packages (from matplotlib>=2.1.0->pycocotools) (0.11.0)\n",
      "Requirement already satisfied: six>=1.5 in /kernel/lib/python3.8/site-packages (from python-dateutil>=2.1->matplotlib>=2.1.0->pycocotools) (1.16.0)\n",
      "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.3 is available.\n",
      "You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n",
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Requirement already satisfied: tqdm in /usr/local/lib/python3.8/dist-packages (4.50.0)\n",
      "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.3 is available.\n",
      "You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n",
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Requirement already satisfied: torchvision in /usr/local/lib/python3.8/dist-packages (0.10.1+cu111)\n",
      "Collecting torchvision\n",
      "  Downloading torchvision-0.13.1-cp38-cp38-manylinux1_x86_64.whl (19.1 MB)\n",
      "     |████████████████████████████████| 19.1 MB 1.7 MB/s            \n",
      "\u001b[?25hRequirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torchvision) (3.7.4.3)\n",
      "Collecting torch==1.12.1\n",
      "  Downloading torch-1.12.1-cp38-cp38-manylinux1_x86_64.whl (776.3 MB)\n",
      "     |████████████████████████████████| 776.3 MB 372 bytes/s           \n",
      "\u001b[?25hRequirement already satisfied: requests in /kernel/lib/python3.8/site-packages (from torchvision) (2.25.1)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /kernel/lib/python3.8/site-packages (from torchvision) (9.2.0)\n",
      "Requirement already satisfied: numpy in /kernel/fallback/lib/python3.8/site-packages (from torchvision) (1.19.4)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /kernel/lib/python3.8/site-packages (from requests->torchvision) (2.10)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /kernel/lib/python3.8/site-packages (from requests->torchvision) (1.26.12)\n",
      "Requirement already satisfied: chardet<5,>=3.0.2 in /kernel/lib/python3.8/site-packages (from requests->torchvision) (4.0.0)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /kernel/lib/python3.8/site-packages (from requests->torchvision) (2022.9.24)\n",
      "Installing collected packages: torch, torchvision\n",
      "\u001b[33m  WARNING: The scripts convert-caffe2-to-onnx, convert-onnx-to-caffe2 and torchrun are installed in '/home/jupyter/.local/bin' which is not on PATH.\n",
      "  Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.\u001b[0m\n",
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "torchaudio 0.9.1 requires torch==1.9.1, but you have torch 1.12.1 which is incompatible.\u001b[0m\n",
      "Successfully installed torch-1.12.1 torchvision-0.13.1\n",
      "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.3 is available.\n",
      "You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n",
      "Defaulting to user installation because normal site-packages is not writeable\n",
      "Requirement already satisfied: torch in /home/jupyter/.local/lib/python3.8/site-packages (1.12.1)\n",
      "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch) (3.7.4.3)\n",
      "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.3 is available.\n",
      "You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# %pip install pycocotools\n",
    "# %pip install tqdm\n",
    "# %pip install torchvision -U\n",
    "# %pip install torch -U"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e8527c6e",
   "metadata": {
    "cellId": "ech9jthm9zrjuq7fn2bs"
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import torch\n",
    "import torchvision\n",
    "from torch.utils.data import DataLoader\n",
    "from masks_for_mask_r_cnn_dataset import MasksForMaskRCNNDataset\n",
    "from custom_segmentation_transforms import Compose, ToTensor, RandomHorizontalFlip\n",
    "from image_utils import show_image, build_box_masks, merge_image_and_masks_boxes, merge_masks_with_colors\n",
    "from torch.utils.data import random_split\n",
    "from maskrcnn_model import build_maskrsnn_model\n",
    "from metrics import accumulate_metrics, compute_metrics, coco_metric_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "57636714",
   "metadata": {
    "cellId": "zamp0lyvjrjlqyekmhu2am"
   },
   "outputs": [],
   "source": [
    "ROOT = '/home/jupyter/mnt/s3'\n",
    "DS_ROOT = f'{ROOT}/pennfudanped'\n",
    "DS_MASKS = f'{DS_ROOT}/PedMasks'\n",
    "DS_IMAGES = f'{DS_ROOT}/PNGImages'\n",
    "\n",
    "PARAMS = {\n",
    "    'batch_size': 1,\n",
    "    'epochs': 1,\n",
    "    'lr': 0.001,\n",
    "    'momentum': 0.9,\n",
    "    'weight_decay': 0.0005,\n",
    "    'step_size': 3,\n",
    "    'gamma': 0.1,\n",
    "    'pretrained': torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.DEFAULT\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "02cfedc8",
   "metadata": {
    "cellId": "0opprkt2dlxsvv4kuogazah"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_of_targets = 1\n"
     ]
    }
   ],
   "source": [
    "full_dataset = MasksForMaskRCNNDataset(\n",
    "    images_root=DS_IMAGES,\n",
    "    masks_root=DS_MASKS,\n",
    "    transforms=Compose([\n",
    "        ToTensor(),\n",
    "        RandomHorizontalFlip(0.5),\n",
    "    ])\n",
    ")\n",
    "num_of_targets = 1\n",
    "print(f'num_of_targets = {num_of_targets}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "142cf22a",
   "metadata": {
    "cellId": "eam1tkh4we33jiixj8t46"
   },
   "outputs": [],
   "source": [
    "train_size = int(0.8 * len(full_dataset))\n",
    "val_size = len(full_dataset) - train_size\n",
    "\n",
    "dataset_train, dataset_val = random_split(full_dataset,\n",
    "                                          [train_size, val_size],\n",
    "                                          generator=torch.Generator().manual_seed(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "839f0590",
   "metadata": {
    "cellId": "47b0jkr0hu7uznuewj20p"
   },
   "outputs": [],
   "source": [
    "train_loader = DataLoader(\n",
    "    dataset_train,\n",
    "    batch_size=PARAMS['batch_size'],\n",
    "    shuffle=True,\n",
    "    pin_memory=True,\n",
    "    drop_last=True\n",
    ")\n",
    "val_loader = DataLoader(\n",
    "    dataset_val,\n",
    "    batch_size=PARAMS['batch_size'],\n",
    "    shuffle=True,\n",
    "    pin_memory=True,\n",
    "    drop_last=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c46c38ac",
   "metadata": {
    "cellId": "tii3mrngw7gi0heqnjzir"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 3, 383, 456])\n",
      "torch.Size([1, 1, 4])\n",
      "torch.Size([1, 1])\n",
      "torch.Size([1, 1])\n",
      "torch.Size([1, 1])\n",
      "torch.Size([1, 1])\n"
     ]
    }
   ],
   "source": [
    "image, targets = next(iter(train_loader))\n",
    "\n",
    "print( image.shape )\n",
    "print( targets['boxes'].shape )\n",
    "print( targets['labels'].shape )\n",
    "print( targets['image_id'].shape )\n",
    "print( targets['area'].shape )\n",
    "print( targets['iscrowd'].shape )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "efa59496",
   "metadata": {
    "cellId": "3nww1emr9a6n4r82fa53vj"
   },
   "outputs": [],
   "source": [
    "model = build_maskrsnn_model(num_of_targets, PARAMS['pretrained'])\n",
    "\n",
    "params = [p for p in model.parameters() if p.requires_grad]\n",
    "optimizer = torch.optim.SGD(params, lr=PARAMS['lr'], momentum=PARAMS['momentum'], weight_decay=PARAMS['weight_decay'])\n",
    "lr_scheduler_global = torch.optim.lr_scheduler.StepLR(optimizer, step_size=PARAMS['step_size'], gamma=PARAMS['gamma'])\n",
    "\n",
    "epoch_start = 0\n",
    "best_accuracy = 0\n",
    "best_state = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "83f5867f",
   "metadata": {
    "cellId": "hmz9wi2duwu11ihuuu8m9u"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device = cuda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: \"https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth\" to /tmp/xdg_cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "979c05f7b1334fad87c241ff5f61e8c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=178090079.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 1/136 [00:05<12:16,  5.46s/it]../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [9,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [11,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [14,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [18,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [20,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [22,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [23,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [25,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [26,0,0] Assertion `t >= 0 && t < n_classes` failed.\n",
      "../aten/src/ATen/native/cuda/Loss.cu:271: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [30,0,0] Assertion `t >= 0 && t < n_classes` failed.\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-ab096cc10900>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     51\u001b[0m         \u001b[0;31m#\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     52\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mamp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautocast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menabled\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscaler\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m             \u001b[0mloss_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimages\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     54\u001b[0m             \u001b[0mlosses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mloss_dict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     55\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1128\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1129\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1131\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1132\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, images, targets)\u001b[0m\n\u001b[1;32m    103\u001b[0m             \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mOrderedDict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"0\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    104\u001b[0m         \u001b[0mproposals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproposal_losses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrpn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimages\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m         \u001b[0mdetections\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdetector_losses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroi_heads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproposals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimages\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimage_sizes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    106\u001b[0m         \u001b[0mdetections\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpostprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdetections\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimages\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimage_sizes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moriginal_image_sizes\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[operator]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1128\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1129\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1131\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1132\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, features, proposals, image_shapes, targets)\u001b[0m\n\u001b[1;32m    770\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mregression_targets\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    771\u001b[0m                 \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"regression_targets cannot be None\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 772\u001b[0;31m             \u001b[0mloss_classifier\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_box_reg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfastrcnn_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclass_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbox_regression\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mregression_targets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    773\u001b[0m             \u001b[0mlosses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\"loss_classifier\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mloss_classifier\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"loss_box_reg\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mloss_box_reg\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    774\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py\u001b[0m in \u001b[0;36mfastrcnn_loss\u001b[0;34m(class_logits, box_regression, labels, regression_targets)\u001b[0m\n\u001b[1;32m     34\u001b[0m     \u001b[0;31m# the corresponding ground truth labels, to be used with\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     35\u001b[0m     \u001b[0;31m# advanced indexing\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m     \u001b[0msampled_pos_inds_subset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwhere\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlabels\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     37\u001b[0m     \u001b[0mlabels_pos\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0msampled_pos_inds_subset\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     38\u001b[0m     \u001b[0mN\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_classes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclass_logits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1."
     ]
    }
   ],
   "source": [
    "#!g2.mig\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print( f'device = {device}' )\n",
    "\n",
    "model.to(device)\n",
    "model.eval()\n",
    "\n",
    "for epoch in range(PARAMS['epochs']):\n",
    "    pbar = tqdm(total=len(train_loader.dataset))\n",
    "\n",
    "    acc_loss_value = 0\n",
    "    acc_loss_classifier = 0\n",
    "    acc_loss_box_reg = 0\n",
    "    acc_loss_mask = 0\n",
    "    acc_loss_objectness = 0\n",
    "    acc_loss_rpn_box_reg = 0\n",
    "\n",
    "    model.train()\n",
    "    scaler = None\n",
    "    lr_scheduler = None\n",
    "    if epoch == 0:\n",
    "        warmup_factor = 1.0 / 1000\n",
    "        warmup_iters = min(1000, len(train_loader) - 1)\n",
    "        lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n",
    "            optimizer, start_factor=warmup_factor, total_iters=warmup_iters\n",
    "        )\n",
    "    for images, targets in train_loader:\n",
    "        pbar.update(len(images))\n",
    "\n",
    "        images = list(image.to(device) for image in images)\n",
    "        # targets = [{k: v.to(device) for k, v in t.items()} for t in targets]\n",
    "        device_target = {}\n",
    "        for k in ( 'boxes', 'labels', 'masks', 'image_id', 'area', 'iscrowd' ):\n",
    "            device_target[k] = targets[k].to(device)\n",
    "        targets = [{\n",
    "            'boxes': device_target['boxes'][0,:,:],\n",
    "            'labels': device_target['labels'][0,:],\n",
    "            'masks': device_target['masks'][0,:],\n",
    "            'image_id': device_target['image_id'][0,:],\n",
    "            'area': device_target['area'][0,:],\n",
    "            'iscrowd': device_target['iscrowd'][0,:],\n",
    "        }]\n",
    "        #\n",
    "        with torch.cuda.amp.autocast(enabled=scaler is not None):\n",
    "            loss_dict = model(images, targets)\n",
    "            losses = sum(loss for loss in loss_dict.values())\n",
    "\n",
    "        loss_value = losses.item()\n",
    "        acc_loss_value += loss_value\n",
    "        acc_loss_classifier += loss_dict['loss_classifier'].item()\n",
    "        acc_loss_box_reg += loss_dict['loss_box_reg'].item()\n",
    "        acc_loss_mask += loss_dict['loss_mask'].item()\n",
    "        acc_loss_objectness += loss_dict['loss_objectness'].item()\n",
    "        acc_loss_rpn_box_reg += loss_dict['loss_rpn_box_reg'].item()\n",
    "\n",
    "        if not math.isfinite(loss_value):\n",
    "            print(f\"Loss is {loss_value}, stopping training\")\n",
    "            sys.exit(1)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        if scaler is not None:\n",
    "            scaler.scale(losses).backward()\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "        else:\n",
    "            losses.backward()\n",
    "            optimizer.step()\n",
    "        if lr_scheduler is not None:\n",
    "            lr_scheduler.step()\n",
    "\n",
    "    pbar.close()\n",
    "\n",
    "    current_lr = optimizer.param_groups[0][\"lr\"]\n",
    "    lr_scheduler_global.step()\n",
    "\n",
    "    model.eval()\n",
    "    pbar = tqdm(total=len(val_loader.dataset))\n",
    "    # iou_types = \"segm\"  # \"segm\", \"bbox\", \"keypoints\"\n",
    "    segmPredicted = []\n",
    "    bboxPredicted = []\n",
    "    for images, targets in val_loader:\n",
    "        pbar.update(len(images))\n",
    "        images = list(img.to(device) for img in images)\n",
    "        if torch.cuda.is_available():\n",
    "            torch.cuda.synchronize()\n",
    "        predictions = model(images)\n",
    "        accumulate_metrics(segmPredicted, bboxPredicted, targets, predictions)\n",
    "\n",
    "    metrics = compute_metrics(val_loader.dataset.coco, segmPredicted, bboxPredicted)\n",
    "    pbar.close()\n",
    "\n",
    "    print(\"Epoch: {0}; lr={1:.4f}; loss={2:.4f}; loss_mask={3:.4f}; cls={4:.4f}; box={5:.4f}; segm.mAP={6:.3f}; bbox.mAP={6:.3f}\"\n",
    "          .format(epoch, current_lr, acc_loss_value, acc_loss_mask, acc_loss_classifier, acc_loss_box_reg, metrics['segm']['mAP'], metrics['bbox']['mAP']))\n",
    "\n",
    "    metrics_summary = {\n",
    "        'current_lr': current_lr,\n",
    "        'acc_loss_value': acc_loss_value,\n",
    "        'acc_loss_mask': acc_loss_mask,\n",
    "        'acc_loss_classifier': acc_loss_classifier,\n",
    "        'acc_loss_box_reg': acc_loss_box_reg,\n",
    "        'acc_loss_objectness': acc_loss_objectness,\n",
    "        'acc_loss_rpn_box_reg': acc_loss_rpn_box_reg,\n",
    "    }\n",
    "    for mtype in ('segm', 'bbox'):\n",
    "        for metric_key in coco_metric_names.keys():\n",
    "            metrics_summary[mtype+'.'+metric_key] = metrics[mtype][metric_key]\n",
    "    # mlflow.log_metrics(metrics_summary)\n",
    "\n",
    "#     mlflow.pytorch.log_state_dict(artifact_path=\"checkpoint\", state_dict={\n",
    "#         \"model\": model.state_dict(),\n",
    "#         \"optimizer\": optimizer.state_dict(),\n",
    "#         \"epoch\": epoch,\n",
    "#         \"best_accuracy\": metrics['segm']['mAP']\n",
    "#     })\n",
    "    if metrics['segm']['mAP'] > best_accuracy:\n",
    "        best_accuracy = metrics['segm']['mAP']\n",
    "        best_state = copy.deepcopy(model.state_dict())\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a926cacf",
   "metadata": {
    "cellId": "74qvm7e2emcpjd0k860d6"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "notebookId": "f1223451-1595-404f-8bde-8aa2570af0fe",
  "notebookPath": "demo-ml-pennfudanped/train_model.ipynb"
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
